{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准备数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
      "Failed to download (trying next):\n",
      "HTTP Error 404: Not Found\n",
      "\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\\MNIST\\raw\\train-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100.0%\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data\\MNIST\\raw\\train-images-idx3-ubyte.gz to ./data\\MNIST\\raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
      "Failed to download (trying next):\n",
      "HTTP Error 404: Not Found\n",
      "\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\\MNIST\\raw\\train-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100.0%\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data\\MNIST\\raw\\train-labels-idx1-ubyte.gz to ./data\\MNIST\\raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
      "Failed to download (trying next):\n",
      "HTTP Error 404: Not Found\n",
      "\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\\MNIST\\raw\\t10k-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100.0%\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data\\MNIST\\raw\\t10k-images-idx3-ubyte.gz to ./data\\MNIST\\raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Failed to download (trying next):\n",
      "HTTP Error 404: Not Found\n",
      "\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100.0%\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz to ./data\\MNIST\\raw\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# 禁用PyTorch的梯度跟踪\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "# 设备配置（仅使用CPU）\n",
    "device = torch.device('cpu')\n",
    "\n",
    "# 数据预处理（转换为numpy数组）\n",
    "def to_numpy(dataset):\n",
    "    data = dataset.data.numpy().astype(np.float32) / 255.0\n",
    "    labels = dataset.targets.numpy()\n",
    "    one_hot = np.zeros((labels.size, 10), dtype=np.float32)\n",
    "    one_hot[np.arange(labels.size), labels] = 1.0\n",
    "    return data.reshape(-1, 28*28), one_hot\n",
    "\n",
    "# 加载数据集\n",
    "train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())\n",
    "test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())\n",
    "\n",
    "# 转换为numpy数组\n",
    "train_data, train_labels = to_numpy(train_dataset)\n",
    "test_data, test_labels = to_numpy(test_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 手动实现全连接层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinearLayer:\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        self.W = np.random.randn(input_dim, output_dim) * 0.01\n",
    "        self.b = np.zeros((1, output_dim))\n",
    "        self.cache = None\n",
    "    \n",
    "    def forward(self, x):\n",
    "        out = np.dot(x, self.W) + self.b\n",
    "        self.cache = x  # 保存输入用于反向传播\n",
    "        return out\n",
    "    \n",
    "    def backward(self, dout):\n",
    "        x = self.cache\n",
    "        dW = np.dot(x.T, dout)\n",
    "        db = np.sum(dout, axis=0, keepdims=True)\n",
    "        dx = np.dot(dout, self.W.T)\n",
    "        return dx, dW, db\n",
    "\n",
    "# 手动实现ReLU\n",
    "class ReLU:\n",
    "    def __init__(self):\n",
    "        self.cache = None\n",
    "    \n",
    "    def forward(self, x):\n",
    "        self.cache = x\n",
    "        return np.maximum(0, x)\n",
    "    \n",
    "    def backward(self, dout):\n",
    "        x = self.cache\n",
    "        dx = dout * (x > 0)\n",
    "        return dx\n",
    "\n",
    "# 手动实现Softmax+CrossEntropy\n",
    "class SoftmaxCELoss:\n",
    "    def __init__(self):\n",
    "        self.cache = None\n",
    "    \n",
    "    def forward(self, x, y_true):\n",
    "        # Softmax\n",
    "        exp_x = np.exp(x - np.max(x, axis=1, keepdims=True))\n",
    "        probs = exp_x / np.sum(exp_x, axis=1, keepdims=True)\n",
    "        \n",
    "        # Cross-Entropy Loss\n",
    "        loss = -np.sum(y_true * np.log(probs + 1e-12)) / x.shape[0]\n",
    "        self.cache = (probs, y_true)\n",
    "        return loss\n",
    "    \n",
    "    def backward(self):\n",
    "        probs, y_true = self.cache\n",
    "        dx = (probs - y_true) / probs.shape[0]\n",
    "        return dx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 建立模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 模型定义\n",
    "class ManualNN:\n",
    "    def __init__(self):\n",
    "        self.fc1 = LinearLayer(28*28, 100)\n",
    "        self.relu = ReLU()\n",
    "        self.fc2 = LinearLayer(100, 10)\n",
    "        self.loss = SoftmaxCELoss()\n",
    "    \n",
    "    def forward(self, x, y_true):\n",
    "        # 前向传播\n",
    "        x = self.fc1.forward(x)\n",
    "        x = self.relu.forward(x)\n",
    "        x = self.fc2.forward(x)\n",
    "        loss = self.loss.forward(x, y_true)\n",
    "        return loss\n",
    "    \n",
    "    def backward(self):\n",
    "        # 反向传播\n",
    "        dx = self.loss.backward()\n",
    "        dx, dW2, db2 = self.fc2.backward(dx)\n",
    "        dx = self.relu.backward(dx)\n",
    "        dx, dW1, db1 = self.fc1.backward(dx)\n",
    "        return dW1, db1, dW2, db2\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 实际训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/50 | Loss: 2.3040 | Test Acc: 10.96%\n",
      "Epoch 2/50 | Loss: 2.3043 | Test Acc: 11.23%\n",
      "Epoch 3/50 | Loss: 2.3039 | Test Acc: 11.50%\n",
      "Epoch 4/50 | Loss: 2.3035 | Test Acc: 11.78%\n",
      "Epoch 5/50 | Loss: 2.3025 | Test Acc: 12.03%\n",
      "Epoch 6/50 | Loss: 2.3022 | Test Acc: 12.23%\n",
      "Epoch 7/50 | Loss: 2.3007 | Test Acc: 12.48%\n",
      "Epoch 8/50 | Loss: 2.3024 | Test Acc: 12.82%\n",
      "Epoch 9/50 | Loss: 2.3051 | Test Acc: 13.19%\n",
      "Epoch 10/50 | Loss: 2.3041 | Test Acc: 13.30%\n",
      "Epoch 11/50 | Loss: 2.3022 | Test Acc: 13.63%\n",
      "Epoch 12/50 | Loss: 2.3040 | Test Acc: 13.92%\n",
      "Epoch 13/50 | Loss: 2.3029 | Test Acc: 14.21%\n",
      "Epoch 14/50 | Loss: 2.3026 | Test Acc: 14.58%\n",
      "Epoch 15/50 | Loss: 2.3001 | Test Acc: 14.81%\n",
      "Epoch 16/50 | Loss: 2.3011 | Test Acc: 14.99%\n",
      "Epoch 17/50 | Loss: 2.3017 | Test Acc: 15.27%\n",
      "Epoch 18/50 | Loss: 2.3011 | Test Acc: 15.53%\n",
      "Epoch 19/50 | Loss: 2.3013 | Test Acc: 15.99%\n",
      "Epoch 20/50 | Loss: 2.3013 | Test Acc: 16.38%\n",
      "Epoch 21/50 | Loss: 2.3019 | Test Acc: 16.74%\n",
      "Epoch 22/50 | Loss: 2.3010 | Test Acc: 17.15%\n",
      "Epoch 23/50 | Loss: 2.3028 | Test Acc: 17.55%\n",
      "Epoch 24/50 | Loss: 2.3005 | Test Acc: 17.93%\n",
      "Epoch 25/50 | Loss: 2.3010 | Test Acc: 18.46%\n",
      "Epoch 26/50 | Loss: 2.3023 | Test Acc: 18.68%\n",
      "Epoch 27/50 | Loss: 2.3008 | Test Acc: 19.01%\n",
      "Epoch 28/50 | Loss: 2.2998 | Test Acc: 19.38%\n",
      "Epoch 29/50 | Loss: 2.3001 | Test Acc: 19.79%\n",
      "Epoch 30/50 | Loss: 2.3011 | Test Acc: 20.24%\n",
      "Epoch 31/50 | Loss: 2.3007 | Test Acc: 20.66%\n",
      "Epoch 32/50 | Loss: 2.3011 | Test Acc: 20.99%\n",
      "Epoch 33/50 | Loss: 2.2993 | Test Acc: 21.36%\n",
      "Epoch 34/50 | Loss: 2.2983 | Test Acc: 21.69%\n",
      "Epoch 35/50 | Loss: 2.3004 | Test Acc: 22.06%\n",
      "Epoch 36/50 | Loss: 2.3011 | Test Acc: 22.49%\n",
      "Epoch 37/50 | Loss: 2.2990 | Test Acc: 22.88%\n",
      "Epoch 38/50 | Loss: 2.2987 | Test Acc: 23.24%\n",
      "Epoch 39/50 | Loss: 2.3007 | Test Acc: 23.54%\n",
      "Epoch 40/50 | Loss: 2.3010 | Test Acc: 23.96%\n",
      "Epoch 41/50 | Loss: 2.2974 | Test Acc: 24.30%\n",
      "Epoch 42/50 | Loss: 2.3011 | Test Acc: 24.75%\n",
      "Epoch 43/50 | Loss: 2.2993 | Test Acc: 25.13%\n",
      "Epoch 44/50 | Loss: 2.2997 | Test Acc: 25.49%\n",
      "Epoch 45/50 | Loss: 2.2989 | Test Acc: 25.79%\n",
      "Epoch 46/50 | Loss: 2.2980 | Test Acc: 26.15%\n",
      "Epoch 47/50 | Loss: 2.2972 | Test Acc: 26.44%\n",
      "Epoch 48/50 | Loss: 2.2994 | Test Acc: 26.80%\n",
      "Epoch 49/50 | Loss: 2.2987 | Test Acc: 27.08%\n",
      "Epoch 50/50 | Loss: 2.2988 | Test Acc: 27.45%\n"
     ]
    }
   ],
   "source": [
    "# 训练参数\n",
    "model = ManualNN()\n",
    "learning_rate = 1e-5\n",
    "batch_size = 64\n",
    "n_epochs = 50\n",
    "\n",
    "# 训练循环\n",
    "for epoch in range(n_epochs):\n",
    "    # 随机打乱数据\n",
    "    permutation = np.random.permutation(train_data.shape[0])\n",
    "    \n",
    "    # Mini-batch训练\n",
    "    for i in range(0, train_data.shape[0], batch_size):\n",
    "        # 获取batch数据\n",
    "        indices = permutation[i:i+batch_size]\n",
    "        x_batch = train_data[indices]\n",
    "        y_batch = train_labels[indices]\n",
    "        \n",
    "        # 前向传播\n",
    "        loss = model.forward(x_batch, y_batch)\n",
    "        \n",
    "        # 反向传播计算梯度\n",
    "        dW1, db1, dW2, db2 = model.backward()\n",
    "        \n",
    "        # 手动更新参数\n",
    "        model.fc1.W -= learning_rate * dW1\n",
    "        model.fc1.b -= learning_rate * db1\n",
    "        model.fc2.W -= learning_rate * dW2\n",
    "        model.fc2.b -= learning_rate * db2\n",
    "    \n",
    "    # 每个epoch计算测试集准确率\n",
    "    # 前向传播计算预测结果\n",
    "    scores = model.fc2.forward(\n",
    "        model.relu.forward(\n",
    "            model.fc1.forward(test_data)\n",
    "        )\n",
    "    )\n",
    "    preds = np.argmax(scores, axis=1)\n",
    "    truth = np.argmax(test_labels, axis=1)\n",
    "    accuracy = np.mean(preds == truth)\n",
    "    \n",
    "    print(f\"Epoch {epoch+1}/{n_epochs} | Loss: {loss:.4f} | Test Acc: {accuracy*100:.2f}%\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
